Example 1: Bayesian filtering of cardiac volatility#

%%capture
import sys
if 'google.colab' in sys.modules:
    ! pip install pyhgf systole
from pyhgf.distribution import HGFDistribution
from pyhgf.model import HGF
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from systole.detection import ecg_peaks
from systole.utils import input_conversion
from systole import import_dataset1
from systole.plots import plot_raw
from bokeh.io import output_notebook
from bokeh.plotting import show

output_notebook()

The nodalized version of the Hierarchical Gaussian Filter implemented in pyhgf makes it possible to build filters with multiple inputs. Here, we illustrate how this feature can be used to create an agent that filters its physiological signal in real time. We use a two-level Hierarchical Gaussian Filter to predict the dynamics of the instantaneous heart rate (the RR interval measured at each heartbeat). We then extract the trajectory of surprise at each predictive node and relate it to the cognitive task the participant performed while the signal was being recorded.

Loading and preprocessing physiological recording#

We use the physiological dataset included in Systole as an example. This dataset contains electrocardiography (ECG) and respiration recordings.

# import the ECG and respiration recordings as a pandas data frame
physio_df = import_dataset1(modalities=['ECG', 'Respiration'])

# extract the ECG channel
ecg = physio_df.ecg
Downloading ECG channel: 100%|██████████| 2/2
Downloading Respiration channel: 100%|██████████| 2/2

Plot the signal and the derived instantaneous heart rate#

show(
    plot_raw(ecg, modality='ecg', sfreq=1000, show_heart_rate=True, backend="bokeh")
)

Preprocessing#

# detect R peaks using the Pan-Tompkins algorithm
_, peaks = ecg_peaks(physio_df.ecg)

# convert the peaks into a RR time series
rr = input_conversion(x=peaks, input_type="peaks", output_type="rr_s")
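Under the hood, converting peak locations into an RR time series amounts to taking the time elapsed between successive R peaks. A minimal numpy sketch on simulated peaks (the variables here are illustrative, not Systole's internals; `peaks` mimics the boolean vector sampled at 1000 Hz that `ecg_peaks` returns):

```python
import numpy as np

# simulated boolean peak vector at 1000 Hz: one R peak every second
sfreq = 1000
peaks = np.zeros(5 * sfreq, dtype=bool)
peaks[::sfreq] = True

# RR intervals (in seconds) are the differences between successive peak times
peak_times = np.where(peaks)[0] / sfreq
rr = np.diff(peak_times)
```

With five equally spaced peaks this yields four RR intervals of exactly one second each, i.e. an instantaneous heart rate of 60 bpm.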

Model#

Note

Here we use the total Gaussian surprise (pyhgf.response.total_gaussian_surprise) as a response function. This response function deviates from the default behaviour for the continuous HGF in that it returns the sum of the surprise for all the probabilistic nodes in the network, whereas the default (pyhgf.response.first_level_gaussian_surprise) only computes the surprise at the first level (i.e. the value parent of the continuous input node). We explicitly specify this parameter here to indicate that we want our model to minimise its prediction errors over all variables, and not only at the observation level. In this case, however, the results are expected to be very similar between the two methods.
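For reference, the Gaussian surprise at a node is the negative log of the predicted Gaussian density of the observation; the total surprise minimised here is this quantity summed over all probabilistic nodes. A minimal numpy sketch of the per-node computation (the exact node bookkeeping in pyhgf differs):

```python
import numpy as np

def gaussian_surprise(x, mu, precision):
    """Negative log density of x under a Gaussian with mean mu and precision (1/variance)."""
    return 0.5 * (np.log(2 * np.pi / precision) + precision * (x - mu) ** 2)

# surprise is lowest when the observation matches the prediction...
low = gaussian_surprise(x=0.8, mu=0.8, precision=100.0)

# ...and grows with the squared prediction error, scaled by the predicted precision
high = gaussian_surprise(x=1.2, mu=0.8, precision=100.0)
```

Minimising the sum of this quantity over nodes drives the model to reduce its prediction errors at every level of the hierarchy, not only at the observation level.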

from pyhgf.response import total_gaussian_surprise
hgf_logp_op = HGFDistribution(
    n_levels=2,
    model_type="continuous",
    input_data=[rr],
    response_function=total_gaussian_surprise,
)
with pm.Model() as two_level_hgf:

    # prior over the volatility of the second level
    omega_2 = pm.Normal("omega_2", -2.0, 2.0)

    # add the HGF log-probability as a potential
    pm.Potential("hgf_loglike", hgf_logp_op(omega_1=-4.0, omega_2=omega_2))
pm.model_to_graphviz(two_level_hgf)
../_images/c8b18ba29046d0bb7d51281956e3dc3f1c82097ded8e4aefd11f8dbf146a3d89.svg
with two_level_hgf:
    idata = pm.sample(chains=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [omega_2]
100.00% [2000/2000 00:11<00:00 Sampling chain 0, 0 divergences]
100.00% [2000/2000 00:08<00:00 Sampling chain 1, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 20 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
az.plot_trace(idata)
plt.tight_layout()
../_images/48db06621ec94591e9e8db9a7d8f2d9a771e0d1a24f7b5881bafe5bbae34dc53.png
# retrieve the posterior mean for omega_2
omega_2 = az.summary(idata)["mean"]["omega_2"]
hgf = HGF(
    n_levels=2,
    model_type="continuous",
    initial_mu={"1": rr[0], "2": -4.0},
    initial_pi={"1": 1e4, "2": 1e1},
    omega={"1": -4.0, "2": omega_2},
    rho={"1": 0.0, "2": 0.0},
    kappas={"1": 1.0}).input_data(input_data=rr)
Creating a continuous Hierarchical Gaussian Filter with 2 levels.
... Create the update sequence from the network structure.
... Create the belief propagation function.
... Cache the belief propagation function.
Adding 1935 new observations.
hgf.plot_trajectories();
../_images/b5699c0c2173f59687d0ae400bed850c501d1a79d08a5c6c19db0306413c5e42.png

System configuration#

%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
Last updated: Mon Oct 02 2023

Python implementation: CPython
Python version       : 3.10.13
IPython version      : 8.16.1

pyhgf : 0.0.10
jax   : 0.4.16
jaxlib: 0.4.16

sys       : 3.10.13 (main, Aug 28 2023, 08:28:42) [GCC 11.4.0]
numpy     : 1.22.0
matplotlib: 3.8.0
pymc      : 5.8.2
arviz     : 0.16.1
seaborn   : 0.13.0

Watermark: 2.4.3